import torch
import math

max_len = 10
valid_lens = torch.arange(0, 11, 1)

# Row i of the mask corresponds to valid_lens[i]; column j is a position in
# [0, max_len). Position j is padding iff j >= valid_lens[i], so exactly the
# first valid_lens[i] positions are kept. (The previous ``>`` comparison kept
# one extra position per row, e.g. a length-0 row still exposed position 0.)
# An integer arange keeps the comparison exact against the int64 lengths.
positions = torch.arange(max_len, device=valid_lens.device)
is_padding = positions[None, :] >= valid_lens[:, None]  # bool, (len(valid_lens), max_len)

# Additive float32 mask: 0.0 at kept positions, -inf at padding.
# NOTE(review): presumably meant to be added to attention scores before a
# softmax; a row with valid_len == 0 is all -inf and would softmax to NaN —
# confirm how the consumer handles that row.
mask = torch.zeros(
    is_padding.shape, dtype=torch.float32, device=valid_lens.device
).masked_fill(is_padding, -math.inf)
print(mask)